1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.integration.rsocket;
18
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.handler.CompositeMessageCondition;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition;
import org.springframework.messaging.rsocket.annotation.support.RSocketRequesterMethodArgumentResolver;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59 public class ServerRSocketMessageHandler extends IntegrationRSocketMessageHandler
60 implements ApplicationEventPublisherAware {
61
62 private static final Method HANDLE_CONNECTION_SETUP_METHOD =
63 ReflectionUtils.findMethod(ServerRSocketMessageHandler.class, "handleConnectionSetup", Message.class);
64
65
66 private final Map<Object, RSocketRequester> clientRSocketRequesters = new HashMap<>();
67
68 private BiFunction<Map<String, Object>, DataBuffer, Object> clientRSocketKeyStrategy =
69 (headers, data) -> data.toString(StandardCharsets.UTF_8);
70
71 private ApplicationEventPublisher applicationEventPublisher;
72
73
74
75
76
77
78 public ServerRSocketMessageHandler() {
79 this(false);
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93 public ServerRSocketMessageHandler(boolean messageMappingCompatible) {
94 super(messageMappingCompatible);
95 }
96
97
98
99
100
101
102 public void setClientRSocketKeyStrategy(
103 BiFunction<Map<String, Object>, DataBuffer, Object> clientRSocketKeyStrategy) {
104
105 Assert.notNull(clientRSocketKeyStrategy, "'clientRSocketKeyStrategy' must not be null");
106 this.clientRSocketKeyStrategy = clientRSocketKeyStrategy;
107 }
108
109
110
111
112
113
114 public Map<Object, RSocketRequester> getClientRSocketRequesters() {
115 return Collections.unmodifiableMap(this.clientRSocketRequesters);
116 }
117
118
119
120
121
122
123 @Nullable
124 public RSocketRequester getClientRSocketRequester(Object key) {
125 return this.clientRSocketRequesters.get(key);
126 }
127
128 @Override
129 public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
130 this.applicationEventPublisher = applicationEventPublisher;
131 }
132
133 void registerHandleConnectionSetupMethod() {
134 registerHandlerMethod(this, HANDLE_CONNECTION_SETUP_METHOD,
135 new CompositeMessageCondition(
136 RSocketFrameTypeMessageCondition.CONNECT_CONDITION,
137 new DestinationPatternsMessageCondition(new String[] { "*" }, obtainRouteMatcher())));
138 }
139
140 @SuppressWarnings("unused")
141 private void handleConnectionSetup(Message<DataBuffer> connectMessage) {
142 DataBuffer dataBuffer = connectMessage.getPayload();
143 MessageHeaders messageHeaders = connectMessage.getHeaders();
144 Object rsocketRequesterKey = this.clientRSocketKeyStrategy.apply(messageHeaders, dataBuffer);
145 RSocketRequester rsocketRequester =
146 messageHeaders.get(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER,
147 RSocketRequester.class);
148 this.clientRSocketRequesters.put(rsocketRequesterKey, rsocketRequester);
149 RSocketConnectedEvent rSocketConnectedEvent =
150 new RSocketConnectedEvent(this, messageHeaders, dataBuffer, rsocketRequester);
151 if (this.applicationEventPublisher != null) {
152 this.applicationEventPublisher.publishEvent(rSocketConnectedEvent);
153 }
154 else {
155 if (logger.isInfoEnabled()) {
156 logger.info("The RSocket has been connected: " + rSocketConnectedEvent);
157 }
158 }
159 }
160
161 }